{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploying a MedNIST Classifier with BentoML\n",
    "\n",
    "This notebook demonstrates how to package a trained model with BentoML into an artifact that can be run as a local program performing inference, as a web service doing the same, and as a Docker containerized web service. BentoML also supports deployment to existing platforms such as AWS or Azure, but we focus on local deployment here since researchers are more likely to need it. The tutorial trains a MedNIST classifier like the [MONAI tutorial here](../../2d_classification/mednist_tutorial.ipynb) and then packages it as described in this [BentoML tutorial](https://github.com/bentoml/gallery/blob/master/pytorch/fashion-mnist/pytorch-fashion-mnist.ipynb). Note that it uses the legacy BentoML 0.13 API (`bentoml==0.13.1`, installed below); BentoML 1.0 and later replaced `BentoService` and the decorators used here with a different API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, ignite, tqdm]\"\n",
    "!pip install -q bentoml==0.13.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 1.3.2\n",
      "Numpy version: 1.26.4\n",
      "Pytorch version: 2.3.1+cu121\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e\n",
      "MONAI __file__: /home/<username>/anaconda3/envs/monai/lib/python3.9/site-packages/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.11\n",
      "ITK version: 5.4.0\n",
      "Nibabel version: 5.2.1\n",
      "scikit-image version: 0.24.0\n",
      "scipy version: 1.13.1\n",
      "Pillow version: 10.4.0\n",
      "Tensorboard version: 2.17.0\n",
      "gdown version: 5.2.0\n",
      "TorchVision version: 0.18.1+cu121\n",
      "tqdm version: 4.66.4\n",
      "lmdb version: 1.5.1\n",
      "psutil version: 6.0.0\n",
      "pandas version: 2.2.2\n",
      "einops version: 0.7.0\n",
      "transformers version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "mlflow version: 2.14.2\n",
      "pynrrd version: 1.0.0\n",
      "clearml version: NOT INSTALLED or UNKNOWN VERSION.\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "import glob\n",
    "import PIL.Image\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from ignite.engine import Events\n",
    "\n",
    "from monai.apps import download_and_extract\n",
    "from monai.config import print_config\n",
    "from monai.networks.nets import DenseNet121\n",
    "from monai.engines import SupervisedTrainer\n",
    "from monai.transforms import (\n",
    "    EnsureChannelFirst,\n",
    "    Compose,\n",
    "    LoadImage,\n",
    "    RandFlip,\n",
    "    RandRotate,\n",
    "    RandZoom,\n",
    "    ScaleIntensity,\n",
    "    EnsureType,\n",
    ")\n",
    "from monai.utils import set_determinism\n",
    "\n",
    "set_determinism(seed=0)\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "You can specify a directory with the MONAI_DATA_DIRECTORY environment variable.\n",
    "This allows you to save results and reuse downloads.\n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/chyang/Documents/monai_tutorials_gitee/data\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download dataset\n",
    "\n",
    "The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),\n",
    "[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),\n",
    "and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).\n",
    "\n",
    "The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)\n",
    "under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).\n",
    "\n",
    "If you use the MedNIST dataset, please acknowledge the source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\"\n",
    "md5 = \"0bc7306e7427e00ad1c5526a6677552d\"\n",
    "\n",
    "compressed_file = os.path.join(root_dir, \"MedNIST.tar.gz\")\n",
    "data_dir = os.path.join(root_dir, \"MedNIST\")\n",
    "if not os.path.exists(data_dir):\n",
    "    download_and_extract(resource, compressed_file, root_dir, md5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']\n",
      "Label counts: [10000, 8954, 10000, 10000, 10000, 10000]\n",
      "Total image count: 58954\n",
      "Image dimensions: 64 x 64\n"
     ]
    }
   ],
   "source": [
    "subdirs = sorted(glob.glob(f\"{data_dir}/*/\"))\n",
    "\n",
    "class_names = [os.path.basename(sd[:-1]) for sd in subdirs]\n",
    "image_files = [glob.glob(f\"{sb}/*\") for sb in subdirs]\n",
    "\n",
    "image_files_list = sum(image_files, [])\n",
    "image_class = sum(([i] * len(f) for i, f in enumerate(image_files)), [])\n",
    "image_width, image_height = PIL.Image.open(image_files_list[0]).size\n",
    "\n",
    "print(f\"Label names: {class_names}\")\n",
    "print(f\"Label counts: {list(map(len, image_files))}\")\n",
    "print(f\"Total image count: {len(image_class)}\")\n",
    "print(f\"Image dimensions: {image_width} x {image_height}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup and Train\n",
    "\n",
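    "Here we create a transform sequence and train the network. Validation and testing are omitted because the [MONAI MedNIST tutorial](../../2d_classification/mednist_tutorial.ipynb) already shows that this setup works, and they are not needed to demonstrate deployment:"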
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transforms = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        EnsureChannelFirst(),\n",
    "        ScaleIntensity(),\n",
    "        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),\n",
    "        RandFlip(spatial_axis=0, prob=0.5),\n",
    "        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),\n",
    "        EnsureType(),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MedNISTDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, image_files, labels, transforms):\n",
    "        self.image_files = image_files\n",
    "        self.labels = labels\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_files)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.transforms(self.image_files[index]), self.labels[index]\n",
    "\n",
    "\n",
    "# just one dataset and loader, we won't bother with validation or testing\n",
    "train_ds = MedNISTDataset(image_files_list, image_class, train_transforms)\n",
    "train_loader = torch.utils.data.DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(class_names)).to(device)\n",
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "opt = torch.optim.Adam(net.parameters(), 1e-5)\n",
    "max_epochs = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3 Loss: 0.17640653252601624\n",
      "Epoch 2/3 Loss: 0.07659219950437546\n",
      "Epoch 3/3 Loss: 0.06494573503732681\n"
     ]
    }
   ],
   "source": [
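    "def _prepare_batch(batch, device, non_blocking, **kwargs):\n",
    "    # move the (image, label) pair to the training device\n",
    "    return tuple(b.to(device, non_blocking=non_blocking, **kwargs) for b in batch)\n",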
    "\n",
    "\n",
    "trainer = SupervisedTrainer(device, max_epochs, train_loader, net, opt, loss_function, prepare_batch=_prepare_batch)\n",
    "\n",
    "\n",
    "@trainer.on(Events.EPOCH_COMPLETED)\n",
    "def _print_loss(engine):\n",
    "    print(f\"Epoch {engine.state.epoch}/{engine.state.max_epochs} Loss: {engine.state.output[0]['loss']}\")\n",
    "\n",
    "\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
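    "The network is saved here as a TorchScript object. This isn't required for the BentoML service, which packs the network itself as an artifact, but it provides a copy of the model that can be loaded without MONAI code."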
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.jit.script(net).save(\"classifier.zip\")"
   ]
  },
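  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The TorchScript file is not used by the BentoML service below, but it offers a way to run the classifier without MONAI code. The next cell is a minimal sketch of that path, assuming the `classifier.zip` file saved above and input batches already preprocessed to shape `(N, 1, 64, 64)`; the helper names `load_scripted_classifier` and `classify_batch` are hypothetical. It repeats the class list used by the service and checks that it is in sorted order, since the training labels were assigned from the sorted subdirectory names. Calling `eval()` after loading matters here because the network was scripted while still in training mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# must match the label order used in training, i.e. the sorted subdirectory names\n",
    "MEDNIST_CLASSES = [\"AbdomenCT\", \"BreastMRI\", \"CXR\", \"ChestCT\", \"Hand\", \"HeadCT\"]\n",
    "assert MEDNIST_CLASSES == sorted(MEDNIST_CLASSES)\n",
    "\n",
    "\n",
    "def load_scripted_classifier(path=\"classifier.zip\"):\n",
    "    \"\"\"Load a TorchScript classifier onto the CPU; no MONAI import is needed.\"\"\"\n",
    "    model = torch.jit.load(path, map_location=\"cpu\")\n",
    "    # the module was scripted in training mode, so switch batch norm to inference behavior\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def classify_batch(model, batch, classes=MEDNIST_CLASSES):\n",
    "    \"\"\"Run a batch of shape (N, 1, H, W) through `model` and map argmax indices to class names.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        logits = model(batch.float())\n",
    "    return [classes[i] for i in logits.argmax(dim=1).tolist()]\n",
    "\n",
    "\n",
    "# hypothetical usage, where `batch` is a float tensor of shape (N, 1, 64, 64) scaled to [0, 1]:\n",
    "# print(classify_batch(load_scripted_classifier(), batch))"
   ]
  },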
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BentoML Setup\n",
    "\n",
    "BentoML provides its platform through an API that wraps service requests as method calls. This is similar to how Flask works (Flask is one of the underlying technologies used here), but on top of this BentoML provides facilities for storing the network (as artifacts), handling the input and output of requests, and caching data. What we need to provide is a script file defining the services we want; BentoML takes this together with the artifacts we provide and stores them as a bundle in a separate location, which can be run locally as well as uploaded to a server (somewhat like a Docker registry).\n",
    "\n",
    "The script below creates our API, which includes MONAI code. The transform sequence needs a special read transform (`LoadStreamPIL`) to turn a data stream into an image, but otherwise it is like the code used above for training, without the random augmentations. The network is stored as an artifact, which in practice means the serialized network is kept inside the BentoML bundle. This artifact is loaded automatically at runtime, but we could load the TorchScript model instead if we wanted to, in particular if we wanted an API that didn't rely on MONAI code.\n",
    "\n",
    "The script needs to be written out to a file first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting mednist_classifier_bentoml.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile mednist_classifier_bentoml.py\n",
    "\n",
    "from typing import BinaryIO, List\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import torch\n",
    "\n",
    "from monai.transforms import (\n",
    "    EnsureChannelFirst,\n",
    "    Compose,\n",
    "    Transform,\n",
    "    ScaleIntensity,\n",
    "    EnsureType,\n",
    ")\n",
    "\n",
    "import bentoml\n",
    "from bentoml.frameworks.pytorch import PytorchModelArtifact\n",
    "from bentoml.adapters import FileInput, JsonOutput\n",
    "from bentoml.utils import cached_property\n",
    "\n",
    "MEDNIST_CLASSES = [\"AbdomenCT\", \"BreastMRI\", \"CXR\", \"ChestCT\", \"Hand\", \"HeadCT\"]\n",
    "\n",
    "\n",
    "class LoadStreamPIL(Transform):\n",
    "    \"\"\"Load an image file from a data stream using PIL.\"\"\"\n",
    "\n",
    "    def __init__(self, mode=None):\n",
    "        self.mode = mode\n",
    "\n",
    "    def __call__(self, stream):\n",
    "        img = Image.open(stream)\n",
    "\n",
    "        if self.mode is not None:\n",
    "            img = img.convert(mode=self.mode)\n",
    "\n",
    "        return np.array(img)\n",
    "\n",
    "\n",
    "@bentoml.env(pip_packages=[\"torch\", \"numpy\", \"monai\", \"pillow\"])\n",
    "@bentoml.artifacts([PytorchModelArtifact(\"classifier\")])\n",
    "class MedNISTClassifier(bentoml.BentoService):\n",
    "    @cached_property\n",
    "    def transform(self):\n",
    "        return Compose([LoadStreamPIL(\"L\"), EnsureChannelFirst(channel_dim=\"no_channel\"), ScaleIntensity(), EnsureType()])\n",
    "\n",
    "    @bentoml.api(input=FileInput(), output=JsonOutput(), batch=True)\n",
    "    def predict(self, file_streams: List[BinaryIO]) -> List[str]:\n",
    "        img_tensors = list(map(self.transform, file_streams))\n",
    "        batch = torch.stack(img_tensors).float()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = self.artifacts.classifier(batch)\n",
    "        _, output_classes = outputs.max(dim=1)\n",
    "\n",
    "        return [MEDNIST_CLASSES[oc] for oc in output_classes.tolist()]"
   ]
  },
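  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The service's preprocessing can also be reproduced without MONAI, for example to prepare inputs for the TorchScript model saved earlier. The next cell is a minimal NumPy/Pillow sketch that is not part of the BentoML service: the function name `preprocess_stream` is hypothetical, and it approximates `LoadStreamPIL(\"L\")`, `EnsureChannelFirst(channel_dim=\"no_channel\")` and `ScaleIntensity()` (min-max scaling to [0, 1] with its default arguments), returning a NumPy array rather than a tensor. Like `LoadStreamPIL`, it keeps the (H, W) axis order that `np.array` gives for a PIL image; MONAI's `LoadImage`, used for training, may return 2D images with the two spatial axes swapped, so it is worth checking that training and serving inputs share the same orientation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def preprocess_stream(stream):\n",
    "    \"\"\"Approximate the service transform chain: float32 array of shape (1, H, W) in [0, 1].\"\"\"\n",
    "    # decode and convert to 8-bit grayscale, as LoadStreamPIL(\"L\") does\n",
    "    img = np.asarray(Image.open(stream).convert(\"L\"), dtype=np.float32)\n",
    "    # add a leading channel axis, as EnsureChannelFirst(channel_dim=\"no_channel\") does\n",
    "    img = img[np.newaxis]\n",
    "    # min-max scale to [0, 1] like ScaleIntensity(); a constant image maps to all zeros\n",
    "    lo, hi = img.min(), img.max()\n",
    "    if hi > lo:\n",
    "        return (img - lo) / (hi - lo)\n",
    "    return np.zeros_like(img)\n",
    "\n",
    "\n",
    "# quick check on a synthetic in-memory 64 x 64 PNG (a gradient, not a MedNIST image)\n",
    "buffer = io.BytesIO()\n",
    "Image.fromarray((np.arange(64 * 64) % 256).astype(np.uint8).reshape(64, 64)).save(buffer, format=\"PNG\")\n",
    "buffer.seek(0)\n",
    "arr = preprocess_stream(buffer)\n",
    "print(arr.shape, arr.dtype, arr.min(), arr.max())"
   ]
  },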
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next the service class is imported from the script, and the trained network (moved to the CPU and switched to evaluation mode) is packed into the `classifier` artifact. The resulting bundle is then saved to a repository directory on the local machine:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-07-18 14:39:11,114] WARNING - Python 3.9.19 found in current environment is not officially supported by BentoML. The docker base image used is'bentoml/model-server:0.13.1' which will use conda to install Python 3.9.19 in the build process. Supported Python versions are: f3.6, 3.7, 3.8\n",
      "[2024-07-18 14:39:11,115] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n",
      "[2024-07-18 14:39:12,661] INFO - BentoService bundle 'MedNISTClassifier:20240718143911_EFF9DD' saved to: /home/chyang/bentoml/repository/MedNISTClassifier/20240718143911_EFF9DD\n",
      "/home/chyang/bentoml/repository/MedNISTClassifier/20240718143911_EFF9DD\n"
     ]
    }
   ],
   "source": [
    "from mednist_classifier_bentoml import MedNISTClassifier  # noqa: E402\n",
    "\n",
    "bento_svc = MedNISTClassifier()\n",
    "bento_svc.pack(\"classifier\", net.cpu().eval())\n",
    "\n",
    "saved_path = bento_svc.save()\n",
    "\n",
    "print(saved_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can look at the contents of the saved bundle, which include the service code, a `Dockerfile`, and setup scripts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "total 48\n",
      "-rwxrw-r-- 1 chyang chyang 3170 Jul 18 14:39 bentoml-init.sh\n",
      "-rw-rw-r-- 1 chyang chyang  829 Jul 18 14:39 bentoml.yml\n",
      "-rwxrw-r-- 1 chyang chyang  841 Jul 18 14:39 docker-entrypoint.sh\n",
      "-rw-rw-r-- 1 chyang chyang 1593 Jul 18 14:39 Dockerfile\n",
      "-rw-rw-r-- 1 chyang chyang 3317 Jul 18 14:39 docs.json\n",
      "-rw-rw-r-- 1 chyang chyang   49 Jul 18 14:39 environment.yml\n",
      "-rw-rw-r-- 1 chyang chyang   72 Jul 18 14:39 MANIFEST.in\n",
      "drwxrwxr-x 4 chyang chyang 4096 Jul 18 14:39 MedNISTClassifier\n",
      "-rw-rw-r-- 1 chyang chyang    6 Jul 18 14:39 python_version\n",
      "-rw-rw-r-- 1 chyang chyang  298 Jul 18 14:39 README.md\n",
      "-rw-rw-r-- 1 chyang chyang   70 Jul 18 14:39 requirements.txt\n",
      "-rw-rw-r-- 1 chyang chyang 1691 Jul 18 14:39 setup.py\n"
     ]
    }
   ],
   "source": [
    "!ls -l {saved_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The saved bundle can be run like a stored program: we invoke it by service name and version (`MedNISTClassifier:latest` picks the most recent one), name the API we want to use (`predict`), and provide the input as a file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-07-18 14:44:36,127] INFO - Getting latest version MedNISTClassifier:20240718143911_EFF9DD\n",
      "[2024-07-18 14:44:38,919] WARNING - Python 3.9.19 found in current environment is not officially supported by BentoML. The docker base image used is'bentoml/model-server:0.13.1' which will use conda to install Python 3.9.19 in the build process. Supported Python versions are: f3.6, 3.7, 3.8\n",
      "[2024-07-18 14:44:38,973] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n",
      "[2024-07-18 14:44:39,002] INFO - {'service_name': 'MedNISTClassifier', 'service_version': '20240718143911_EFF9DD', 'api': 'predict', 'task': {'data': {'uri': 'file:///home/chyang/Documents/monai_tutorials_gitee/data/MedNIST/AbdomenCT/004980.jpeg', 'name': '004980.jpeg'}, 'task_id': 'ddf5d954-6fa8-4938-aa09-971010d9efe9', 'cli_args': ('--input-file', '/home/chyang/Documents/monai_tutorials_gitee/data/MedNIST/AbdomenCT/004980.jpeg'), 'inference_job_args': {}}, 'result': {'data': '\"AbdomenCT\"', 'http_status': 200, 'http_headers': (('Content-Type', 'application/json'),)}, 'request_id': 'ddf5d954-6fa8-4938-aa09-971010d9efe9'}\n",
      "\"AbdomenCT\"\n"
     ]
    }
   ],
   "source": [
    "!bentoml run MedNISTClassifier:latest predict --input-file {image_files[0][0]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
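    "The service can also be run on a Flask-based web server. The following script starts the server, waits for it to get going, uses curl to send the test file as a POST request to get a prediction, and then kills the server. If an error message is returned, as in the stored output below, rerun the server with its output redirected to a file instead of `/dev/null` to see the details in the server logs:"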
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction: An error has occurred in BentoML user code when handling this request, find the error details in server logs\n"
     ]
    }
   ],
   "source": [
    "%%bash -s {image_files[0][0]}\n",
    "# filename passed in as an argument to the cell\n",
    "test_file=$1\n",
    "\n",
    "# start the Flask-based server, sending output to /dev/null for neatness\n",
    "bentoml serve --port=8001 MedNISTClassifier:latest &> /dev/null &\n",
    "\n",
    "# recall the PID of the server and wait for it to start\n",
    "lastpid=$!\n",
    "sleep 5\n",
    "\n",
    "# send the test file using curl and capture the returned string\n",
    "result=$(curl -s -X POST \"http://127.0.0.1:8001/predict\" -F image=@$test_file)\n",
    "# kill the server\n",
    "kill $lastpid\n",
    "\n",
    "echo \"Prediction: $result\""
   ]
  },
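  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same request can be sent from Python instead of curl. The next cell is a minimal standard-library sketch that builds a `multipart/form-data` body like curl's `-F image=@file`; the form field `image`, port 8001 and the `/predict` route follow the cell above, while the helpers `build_multipart` and `post_image` are hypothetical. Sending the request needs a running server, so that call is left commented out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import uuid\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def build_multipart(field_name, filename, data, boundary=None):\n",
    "    \"\"\"Encode one file as a multipart/form-data body; returns (body bytes, Content-Type value).\"\"\"\n",
    "    boundary = boundary or uuid.uuid4().hex\n",
    "    head = (\n",
    "        f\"--{boundary}\\r\\n\"\n",
    "        f'Content-Disposition: form-data; name=\"{field_name}\"; filename=\"{filename}\"\\r\\n'\n",
    "        \"Content-Type: application/octet-stream\\r\\n\\r\\n\"\n",
    "    ).encode()\n",
    "    tail = f\"\\r\\n--{boundary}--\\r\\n\".encode()\n",
    "    return head + data + tail, f\"multipart/form-data; boundary={boundary}\"\n",
    "\n",
    "\n",
    "def post_image(path, url=\"http://127.0.0.1:8001/predict\", timeout=30):\n",
    "    \"\"\"POST an image file to a running BentoML server and return the response text.\"\"\"\n",
    "    path = Path(path)\n",
    "    body, content_type = build_multipart(\"image\", path.name, path.read_bytes())\n",
    "    request = urllib.request.Request(url, data=body, headers={\"Content-Type\": content_type}, method=\"POST\")\n",
    "    with urllib.request.urlopen(request, timeout=timeout) as response:\n",
    "        return response.read().decode()\n",
    "\n",
    "\n",
    "# needs `bentoml serve --port=8001 MedNISTClassifier:latest` running, e.g. in another terminal\n",
    "# print(post_image(image_files[0][0]))"
   ]
  },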
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The service can be packaged as a Docker container to be started elsewhere as a server:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-07-18 14:40:10,739] INFO - Getting latest version MedNISTClassifier:20240718143911_EFF9DD\n",
       "Found Bento: /home/chyang/bentoml/repository/MedNISTClassifier/20240718143911_EFF9DD\n",
       "Containerizing MedNISTClassifier:20240718143911_EFF9DD with local YataiService and docker daemon from local environment\n",
       "Build container image: mednist-classifier:latest\n"
     ]
    }
   ],
   "source": [
    "!bentoml containerize MedNISTClassifier:latest -t mednist-classifier:latest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "REPOSITORY   TAG       IMAGE ID       CREATED        SIZE\n",
      "ubuntu       latest    de52d803b224   2 months ago   76.2MB\n"
     ]
    }
   ],
   "source": [
    "!docker image ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "Remove the directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
